{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to MONAI\n", "\n", "MONAI is a deep learning library for medical imaging research. It provides tools and utilities for building, training, and evaluating deep learning models in the medical imaging domain.\n", "\n", "## Key Features\n", "\n", "1. **Data Loading and Preprocessing**: MONAI offers data loading and preprocessing utilities, including readers for various medical image formats.\n", "2. **Transformations**: It includes a rich set of transformations for data augmentation, normalization, and other preprocessing steps.\n", "3. **Model Building**: MONAI provides a modular approach to building deep learning models, including popular architectures like U-Net, V-Net, and Residual Networks.\n", "4. **Training and Evaluation**: It offers tools for training models, including loss functions, optimizers, and evaluation metrics, making it easier to develop and deploy medical image segmentation models.\n", "5. **Integration with PyTorch**: MONAI is built on top of PyTorch, so its components integrate directly with PyTorch models, data loaders, and the wider PyTorch ecosystem of tools and libraries.\n", "\n", "## Installation\n", "\n", "```bash\n", "pip install monai\n", "```\n", "\n", "MONAI is already installed in your DL_labs_GPU environment. In this notebook we repeat the steps we carried out in plain PyTorch in the notebook \"Segmentation and UNET\", so that you can see the difference between the two frameworks and how MONAI may simplify building and training deep learning models for medical imaging tasks." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Some imports\n", "\n", "import monai\n", "import torch\n", "from torch.utils.data import DataLoader\n", "from monai.transforms import (\n", " EnsureChannelFirstd,\n", " AsDiscreted,\n", " Compose,\n", " LoadImaged,\n", " Orientationd,\n", " Randomizable,\n", " Resized,\n", " ScaleIntensityd,\n", " Spacingd,\n", " EnsureTyped,\n", " Lambda\n", ")\n", "import monai.metrics as metrics\n", "import os\n", "import tempfile\n", "from utils.decathlon_dataset import get_decathlon_dataloader\n", "from utils.unet import UNET\n", "from utils.train import train\n", "\n", "import matplotlib.pyplot as plt\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2024-09-18 15:23:41,224 - INFO - Verified 'Task04_Hippocampus.tar', md5: 9d24dba78a72977dbd1d2e110310f31b.\n", "2024-09-18 15:23:41,225 - INFO - File exists: cm2003/datasets/Task04_Hippocampus.tar, skipped downloading.\n", "2024-09-18 15:23:41,227 - INFO - Non-empty folder exists in cm2003/datasets/Task04_Hippocampus, skipped extracting.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading dataset: 100%|██████████| 208/208 [00:04<00:00, 49.23it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2024-09-18 15:23:45,576 - INFO - Verified 'Task04_Hippocampus.tar', md5: 9d24dba78a72977dbd1d2e110310f31b.\n", "2024-09-18 15:23:45,577 - INFO - File exists: cm2003/datasets/Task04_Hippocampus.tar, skipped downloading.\n", "2024-09-18 15:23:45,578 - INFO - Non-empty folder exists in cm2003/datasets/Task04_Hippocampus, skipped extracting.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading dataset: 100%|██████████| 52/52 [00:01<00:00, 50.07it/s]\n" ] } ], "source": [ "\n", "root_dir = './utils/datasets'\n", "task = \"Task04_Hippocampus\"\n", "\n", "train_loader = get_decathlon_dataloader(root_dir, task, 
\"training\", batch_size=4, num_workers=2, shuffle=True)\n", "val_loader = get_decathlon_dataloader(root_dir, task, \"validation\", batch_size=4, num_workers=2, shuffle=False)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Load UNET model from MONAI\n", "\n", "model = monai.networks.nets.UNet(\n", " spatial_dims=2,\n", " in_channels=1,\n", " out_channels=3,\n", " channels=(16, 32, 64, 128, 256),\n", " strides=(2, 2, 2, 2),\n", " num_res_units=2,\n", ")\n", "\n", "# Define loss function and optimizer\n", "loss_fn = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", "\n", "# Define device\n", "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implement the training loop for the model with the same structure as for PyTorch. You may even make use of the train function we wrote in PyTorch in `utils/train.py`." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20, Loss: 0.9478\n", "Epoch 2/20, Loss: 0.4815\n", "Epoch 3/20, Loss: 0.2349\n", "Epoch 4/20, Loss: 0.1539\n", "Epoch 5/20, Loss: 0.1230\n", "Epoch 6/20, Loss: 0.1069\n", "Epoch 7/20, Loss: 0.0975\n", "Epoch 8/20, Loss: 0.0881\n", "Epoch 9/20, Loss: 0.0805\n", "Epoch 10/20, Loss: 0.0740\n", "Epoch 11/20, Loss: 0.0672\n", "Epoch 12/20, Loss: 0.0613\n", "Epoch 13/20, Loss: 0.0565\n", "Epoch 14/20, Loss: 0.0531\n", "Epoch 15/20, Loss: 0.0497\n", "Epoch 16/20, Loss: 0.0469\n", "Epoch 17/20, Loss: 0.0438\n", "Epoch 18/20, Loss: 0.0417\n", "Epoch 19/20, Loss: 0.0396\n", "Epoch 20/20, Loss: 0.0373\n" ] } ], "source": [ "from utils.train import train\n", "# Your code here\n", "trained_model = None\n", "\n", "trained_model = train(model, train_loader, loss_fn, optimizer, device, epochs=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's evaluate the model on the validation set with the metrics in MONAI. We give an example using the 95th percentile Hausdorff distance (HD95), but you should look for other metrics appropriate for the segmentation task in the documentation: [https://docs.monai.io/en/stable/metrics.html](https://docs.monai.io/en/stable/metrics.html). The Hausdorff distance is the largest distance from a point on one segmentation boundary to the closest point on the other boundary. HD95 uses the 95th percentile of these distances instead of the maximum, which makes it less sensitive to a few outlier points. In this task we are segmenting the hippocampus, which is a small structure, so keep its size in mind when interpreting the distance. MONAI computes the distance in pixel units unless spacing information is passed to the metric, so the value corresponds to millimeters only if the pixel spacing is 1 mm." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/maia-user/.conda/envs/DL_labs_GPU/lib/python3.12/site-packages/monai/metrics/utils.py:329: UserWarning: the ground truth of class 0 is all 0, this may result in nan/inf distance.\n", " warnings.warn(\n", "/home/maia-user/.conda/envs/DL_labs_GPU/lib/python3.12/site-packages/monai/metrics/utils.py:334: UserWarning: the prediction of class 0 is all 0, this may result in nan/inf distance.\n", " warnings.warn(\n", "/home/maia-user/.conda/envs/DL_labs_GPU/lib/python3.12/site-packages/monai/metrics/utils.py:329: UserWarning: the ground truth of class 1 is all 0, this may result in nan/inf distance.\n", " warnings.warn(\n", "/home/maia-user/.conda/envs/DL_labs_GPU/lib/python3.12/site-packages/monai/metrics/utils.py:334: UserWarning: the prediction of class 1 is all 0, this may result in nan/inf distance.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "95th percentile Hausdorff distance: 2.5270\n" ] } ], "source": [ "import torch\n", "from monai.metrics import HausdorffDistanceMetric\n", "\n", "hausdorff_distance = HausdorffDistanceMetric(include_background=False, reduction=\"mean\", percentile=95)\n", "\n", "# Switch to evaluation mode (affects dropout and normalization layers)\n", "trained_model.eval()\n", "\n", "for val_data in val_loader:\n", "    val_inputs, val_labels = val_data[\"image\"].to(device), val_data[\"label\"].to(device)\n", "    # No gradients are needed for evaluation\n", "    with torch.no_grad():\n", "        val_logits = trained_model(val_inputs)\n", "    # Convert predictions and labels to one-hot tensors of shape (batch, class, H, W)\n", "    val_outputs = torch.nn.functional.one_hot(val_logits.argmax(dim=1), num_classes=3).permute(0, 3, 1, 2)\n", "    val_labels = torch.nn.functional.one_hot(val_labels.long(), num_classes=3).squeeze(1).permute(0, 3, 1, 2)\n", "\n", "    # Identify classes present in both the labels and the predictions of this batch\n", "    valid_classes = torch.logical_and(torch.any(val_labels, dim=(0, 2, 3)), torch.any(val_outputs, dim=(0, 2, 3)))\n", "\n", "    # Only compute Hausdorff distance for valid classes\n", "    if torch.any(valid_classes):\n", "        hausdorff_distance(val_outputs[:, valid_classes], val_labels[:, valid_classes])\n", 
"    else:\n", "        print(\"No valid classes found for Hausdorff distance calculation.\")\n", "\n", "result = hausdorff_distance.aggregate().item()\n", "print(f'95th percentile Hausdorff distance: {result:.4f}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation for two other metrics\n", "\n", "**Implement evaluation for two other metrics in MONAI and comment on the results.**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have trained and evaluated the model, let's change the loss function to DiceLoss, another loss function commonly used in medical image segmentation, and train the model again. Look at the documentation of MONAI to find out how to do this: [https://docs.monai.io/en/stable/losses.html](https://docs.monai.io/en/stable/losses.html)\n", "\n", "After training with the new loss function, evaluate the model on the validation set." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Your code here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optional part\n", "\n", "MONAI also provides implementations of many networks apart from UNET. You can find them in the `monai.networks` module, see [https://docs.monai.io/en/stable/networks.html](https://docs.monai.io/en/stable/networks.html). Try to use one of them for the Hippocampus dataset and evaluate the performance. Compare the results with those of the UNET trained with DiceLoss. 
" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Your code here" ] } ], "metadata": { "kernelspec": { "display_name": "DL_labs_GPU", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 2 }